{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 基本方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "#如何判断一个对象是否为Tensor张量\n",
    "obj = np.arange(1,10)\n",
    "print(torch.is_tensor(obj))\n",
    "obj1 = torch.Tensor(10)\n",
    "print(torch.is_tensor(obj1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float64\n"
     ]
    }
   ],
   "source": [
    "#如何全局设置Tensor数据类型？\n",
    "torch.set_default_tensor_type(torch.DoubleTensor)\n",
    "print(torch.Tensor(2).dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "#如何判断一个对象是否为Pytorch Storage对象\n",
    "#torch.Storage is a contiguous, one-dimensional array of a single data type.\n",
    "print(torch.is_storage(obj))\n",
    "storage = torch.DoubleStorage([2,3])\n",
    "print(torch.is_storage(storage))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12\n"
     ]
    }
   ],
   "source": [
    "#如何获取Tensor中元素的个数\n",
    "a = torch.Tensor(3,4)\n",
    "print(torch.numel(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 6.95231e-310,  6.95231e-310,  6.95229e-310,  4.52355e+257],\n",
      "        [ 1.28626e+248,  2.43209e-152,   1.53250e-94,  2.14728e+243],\n",
      "        [  0.00000e+00,  4.77497e+180,  6.95231e-310,  6.95231e-310],\n",
      "        [ 6.95229e-310,   2.57664e+97,   1.72733e+97,  2.25153e-310],\n",
      "        [ 6.84979e+180, -3.50707e-311,   0.00000e+00,  2.04737e+190],\n",
      "        [ 6.95231e-310,  6.95231e-310,  6.95229e-310,  9.82202e+252],\n",
      "        [ 6.16779e+223,  2.41318e+185,  1.67772e+243,   1.52781e-94],\n",
      "        [  0.00000e+00,  1.99502e+161,  6.95231e-310,  6.95231e-310],\n",
      "        [ 6.95229e-310,  4.24336e+175,  1.06025e-153,  7.36507e+228],\n",
      "        [ 4.67352e+257,  1.30489e+180,   0.00000e+00,   0.00000e+00]])\n"
     ]
    }
   ],
   "source": [
    "#如何设置打印选项？\n",
    "#precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None\n",
    "torch.set_printoptions(precision=5,threshold=100,linewidth=100,edgeitems=4)\n",
    "print(torch.DoubleTensor(10,4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 0., 0.],\n",
      "        [0., 1., 0.],\n",
      "        [0., 0., 1.]])\n"
     ]
    }
   ],
   "source": [
    "#如何创建单位矩阵\n",
    "a = torch.eye(3,3)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19], dtype=torch.int32)\n",
      "tensor([[6.95232e-310, 6.95232e-310, 6.95232e-310],\n",
      "        [6.95232e-310, 6.95232e-310, 6.95232e-310]])\n"
     ]
    }
   ],
   "source": [
    "#如何从numpy多维数组创建Tensor张量\n",
    "a = np.arange(1,20,2)\n",
    "print(torch.from_numpy(a))\n",
    "b = np.ndarray((2,3))\n",
    "print(torch.from_numpy(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 3., 5., 7., 9.], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "#如何创建等差数列\n",
    "b =  torch.Tensor()\n",
    "print(torch.linspace(1,9,5,requires_grad=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.00000e+01, 1.00000e+02, 1.00000e+03, 1.00000e+04, 1.00000e+05], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "#和linspace类似的logspace\n",
    "print(torch.logspace(1,5,5,requires_grad=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "#如何创建元素全部为1的矩阵\n",
    "print(torch.ones(2,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.12264, 0.37531, 0.02611],\n",
      "        [0.38838, 0.80195, 0.34111]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "#如何创建0,1均匀的随机矩阵，形状可以指定\n",
    "print(torch.rand(2,3,requires_grad=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.92620, -2.99641,  0.71482],\n",
      "        [-1.12788, -0.43067,  0.58641]])\n"
     ]
    }
   ],
   "source": [
    "#如何创建标准正太分布随机矩阵？形状可以指定\n",
    "print(torch.randn(2,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2 6 1 7 0 5 4 3 9 8]\n",
      "tensor([2, 3, 8, 7, 9, 4, 6, 5, 1, 0])\n"
     ]
    }
   ],
   "source": [
    "#如何创建随机整数序列，如同numpy.random.permutation?\n",
    "print(numpy.random.permutation(10))\n",
    "print(torch.randperm(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 4 7]\n",
      "tensor([1, 4, 7])\n"
     ]
    }
   ],
   "source": [
    "#如何创建一个列表如同numpy中的arange?\n",
    "print(np.arange(1,10,3))\n",
    "print(torch.arange(1,10,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0.],\n",
      "        [0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "#如何创建一个全0的矩阵？\n",
    "print(torch.zeros(2,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 索引、切片、拼接及换位方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.],\n",
      "        [1., 1., 1.]])\n",
      "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
      "tensor([[[1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.]],\n",
      "\n",
      "        [[1., 1.],\n",
      "         [1., 1.],\n",
      "         [1., 1.]]])\n"
     ]
    }
   ],
   "source": [
    "#如何将多个Tensor按照某一维度拼接起来？\n",
    "tensor = torch.ones(2,3)\n",
    "print(torch.cat([tensor,tensor]))\n",
    "print(torch.cat([tensor,tensor,tensor],dim=1))\n",
    "#stack方法也可以进行tensor的拼接\n",
    "print(torch.stack([tensor,tensor],dim=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]))\n",
      "(tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]))\n"
     ]
    }
   ],
   "source": [
    "#如何将一个Tensor按照指定维度切片成n个分片？\n",
    "#按照第二维度将tensor切分为5个tensor\n",
    "x = torch.ones(2,10)\n",
    "print(torch.chunk(x,5,dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[33., 66.],\n",
       "        [88., 99.]])"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#如何按照索引进行元素的聚合？\n",
    "#如何将元素33,66,88,99聚合在一起?\n",
    "#第一个参数是源tensor,第二个参数为维度，第三个参数为索引\n",
    "x = torch.Tensor([[33,66,9],[1,99,88]])\n",
    "torch.gather(x,1,torch.LongTensor([[0,1],[2,1]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.26020,  0.48570, -0.49547,  1.57589, -0.17201,  0.11631,  0.48189],\n",
      "        [-0.35056, -0.31785,  0.14796, -2.50720,  0.99119,  0.89042, -1.62519]])\n",
      "tensor([[ 0.48570,  1.57589,  0.11631],\n",
      "        [-0.31785, -2.50720,  0.89042]])\n",
      "tensor([[-0.35056, -0.31785,  0.14796, -2.50720,  0.99119,  0.89042, -1.62519]])\n"
     ]
    }
   ],
   "source": [
    "#如何按照索引选择目标数据？\n",
    "#如何取出第2、4、6列数据，返回新的tensor\n",
    "x = torch.randn(2,7)\n",
    "print(x)\n",
    "#第一个参数为源tensor,第二个参数为维度，第三个参数为该维度上的索引\n",
    "y = torch.index_select(x,1,torch.LongTensor([1,3,5]))\n",
    "print(y)\n",
    "z = torch.index_select(x,0,torch.LongTensor([1]))\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.18211, 0.72952, 0.85274, 0.14525],\n",
      "        [0.00986, 0.32629, 0.65850, 0.44055]])\n",
      "tensor([[0, 1, 1, 0],\n",
      "        [0, 0, 1, 0]], dtype=torch.uint8)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.72952, 0.85274, 0.65850])"
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#如何选出满足条件的所有元素\n",
    "#masked_select方法，返回mask标志为1的所有元素组成的1维Tensor\n",
    "x = torch.rand(2,4)\n",
    "print(x)\n",
    "mask = x.ge(0.5)\n",
    "print(mask)\n",
    "torch.masked_select(x,mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.00000, 11.00000],\n",
      "         [66.00000, 88.00000]],\n",
      "\n",
      "        [[22.00000, 33.00000],\n",
      "         [ 0.00000,  0.10000]]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 1],\n",
       "        [0, 1, 0],\n",
       "        [0, 1, 1],\n",
       "        [1, 0, 0],\n",
       "        [1, 0, 1],\n",
       "        [1, 1, 1]])"
      ]
     },
     "execution_count": 141,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#如何找出矩阵中非零元素的索引？\n",
    "#nonzero方法返回非零的索引，结果tensor为二维tensor，行数等于源tensor中非零元素个数，列数等于源tensor的维度\n",
    "x = torch.Tensor([[[0.0,11],[66,88]],[[22,33],[0.0,0.1]]])\n",
    "print(x)\n",
    "torch.nonzero(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n",
      "(tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]]), tensor([[1., 1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1., 1.]]))\n",
      "(tensor([[1.],\n",
      "        [1.]]), tensor([[1., 1.],\n",
      "        [1., 1.]]), tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.]]), tensor([[1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.]]))\n"
     ]
    }
   ],
   "source": [
    "#如何将输入张量分割成相大小的chunks?\n",
    "x = torch.ones(2,10)\n",
    "print(x)\n",
    "#每个块5个长度的元素，dim表示按第二维度\n",
    "print(torch.split(x,5,dim=1))\n",
    "#也可指定一个划分列表，依次表示有1,2,3,4个长度\n",
    "print(torch.split(x,[1,2,3,4],dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 2., 3., 4., 5., 6.]])\n",
      "tensor([[1.],\n",
      "        [2.],\n",
      "        [3.],\n",
      "        [4.],\n",
      "        [5.],\n",
      "        [6.]])\n",
      "tensor([[1., 2., 3., 4., 5., 6.]])\n"
     ]
    }
   ],
   "source": [
    "#如何给矩阵增加维度\n",
    "x = torch.Tensor([1,2,3,4,5,6])\n",
    "#dim关键字参数指定在第几维度增加`[]`\n",
    "y = x.unsqueeze(dim=0)\n",
    "print(y)\n",
    "z = x.unsqueeze(dim=1)\n",
    "print(z)\n",
    "print(torch.unsqueeze(x,dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.,  2.,  3.,  4.],\n",
      "         [22., 33., 44., 55.]]]) torch.Size([1, 2, 4])\n",
      "tensor([[ 0.,  2.,  3.,  4.],\n",
      "        [22., 33., 44., 55.]]) torch.Size([2, 4])\n",
      "tensor([[ 0.,  2.,  3.,  4.],\n",
      "        [22., 33., 44., 55.]]) torch.Size([2, 4])\n"
     ]
    }
   ],
   "source": [
    "#如何去掉`[]`降低维度？去维度为1的\n",
    "x = torch.Tensor([[[0,2,3,4],[22,33,44,55]]])\n",
    "print(x,x.shape)\n",
    "#dim关键字参数指定在地接个维度是`1`,squeeze将去除掉这个维度\n",
    "y = torch.squeeze(x,dim=0)\n",
    "print(y,y.shape)\n",
    "z = x.squeeze(dim=0)\n",
    "print(z,z.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 238,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.18556],\n",
      "        [-0.02856]]) torch.Size([2, 1])\n",
      "tensor([[ 1.18556, -0.02856]]) torch.Size([1, 2])\n",
      "tensor([[ 1.18556, -0.02856]]) torch.Size([1, 2])\n",
      "tensor([[ 1.18556, -0.02856]])\n",
      "tensor([[ 1.18556, -0.02856]])\n"
     ]
    }
   ],
   "source": [
    "#如何实现tensor维度之间的转置\n",
    "#tensor自身的t和transpose方法跟torch上的t和transpose方法功能类似\n",
    "x = torch.randn(2,1)\n",
    "print(x,x.shape)\n",
    "y = torch.t(x)\n",
    "print(y,y.shape)\n",
    "z = torch.transpose(x,1,0)\n",
    "print(z,z.shape)\n",
    "print(x.t())\n",
    "print(x.transpose(1,0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.90266, 0.19014],\n",
      "         [0.64815, 0.87569]],\n",
      "\n",
      "        [[0.94966, 0.53289],\n",
      "         [0.94499, 0.76736]]])\n",
      "(tensor([[0.90266, 0.19014],\n",
      "        [0.94966, 0.53289]]), tensor([[0.64815, 0.87569],\n",
      "        [0.94499, 0.76736]]))\n"
     ]
    }
   ],
   "source": [
    "#unbind删除某一维度之后，返回所有切片组成的元祖列表\n",
    "x = torch.rand(2,2,2)\n",
    "print(x)\n",
    "print(torch.unbind(x,dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 随机抽样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seed:123\n",
      "state:tensor([123,   0,   0,   0,  ...,   0,   0,   0,   0], dtype=torch.uint8) 5048\n"
     ]
    }
   ],
   "source": [
    "#手动设置随机种子\n",
    "torch.manual_seed(123)\n",
    "#如果没有手动设置随机种子，则返回系统生成的随机种子；否则返回手动设置的随机种子\n",
    "seed = torch.initial_seed()\n",
    "print(\"seed:{}\".format(seed))\n",
    "#返回随机生成器状态\n",
    "state = torch.get_rng_state()\n",
    "print(\"state:{}\".format(state),len(state))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.36890, 0.01337, 0.59178],\n",
      "        [0.09264, 0.47245, 0.52203],\n",
      "        [0.60508, 0.53130, 0.94855]])\n",
      "tensor([[1., 0., 1.],\n",
      "        [0., 1., 1.],\n",
      "        [1., 0., 1.]])\n"
     ]
    }
   ],
   "source": [
    "#伯努利分布，结果只有0和1，第一个参数是概率p,并且0<=p<=1,\n",
    "torch.manual_seed(123)\n",
    "a = torch.rand(3, 3)\n",
    "print(a)\n",
    "b = torch.bernoulli(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 361,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 1, 2],\n",
      "        [3, 0, 2]])\n",
      "tensor([[0, 0, 1, 0, 0, 1, 3, 0, 0, 2],\n",
      "        [3, 3, 3, 3, 3, 3, 0, 3, 3, 0]])\n",
      "tensor([[0, 0, 1, 2, 2, 0, 1, 2, 0, 0, 1, 0, 0, 1, 1],\n",
      "        [3, 3, 3, 2, 0, 0, 3, 0, 0, 0, 2, 2, 3, 3, 3]])\n",
      "tensor([[1, 0, 2, 3],\n",
      "        [0, 3, 1, 2]])\n"
     ]
    }
   ],
   "source": [
    "#多项式分布抽样\n",
    "#torch.multimomial第一个参数为多项式权重，可以是向量也可以为矩阵，有权重决定对`下标`的抽样\n",
    "#为向量：replacement表示是否有放回的抽样，如果为True，结果行数为1，列数有num_samples指定；否则行数为1，列数<=权重weights长度\n",
    "weights1 = torch.Tensor([20, 10, 3, 2])\n",
    "a = torch.multinomial(weights,num_samples=3,replacement=False)\n",
    "b = torch.multinomial(weights,num_samples=10,replacement=True)\n",
    "print(a)\n",
    "print(b)\n",
    "#为矩阵：replacement表示是否有放回的抽样，如果为True，结果行数为weights行数，列数由num_samples指定；\n",
    "#否则行数为weights行数，列数<=权重weights每一行长度\n",
    "weights2 = torch.Tensor([[20, 10, 3, 2],[30,4,5,60]])\n",
    "c = torch.multinomial(weights,num_samples=15,replacement=True)\n",
    "d = torch.multinomial(weights,num_samples=4,replacement=False)\n",
    "print(c)\n",
    "print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.4248,  0.3964,  0.4494,  ...,  0.9055,  0.2543, -0.3518])\n",
      "tensor([ 0.5807,  0.4270, -0.2506,  0.2640,  0.1797,  0.4377,  1.1154,  0.6373,\n",
      "         0.4479])\n",
      "tensor([0.0370, 0.2305, 0.3554, 0.6012, 0.6571, 1.5476, 0.1952, 0.6846, 2.8770])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XeYVNX5B/DvuwWW3nbpZelV6tJZQEFEEI1dExsxGEWjiUaD2LBEiQaNJcaQ6M+EgAajKMLSBQGVjvRlaUuHXUA6u2w5vz+m7MzsnTt3Zu7cOzP7/TwPzzPlzr1nL7vvvfOec94jSikQEVH8SLC7AUREZC4GdiKiOMPATkQUZxjYiYjiDAM7EVGcYWAnIoozDOxERHGGgZ2IKM4wsBMRxZkkOw6ampqq0tPT7Tg0EVHMWr9+/QmlVFqg7WwJ7Onp6Vi3bp0dhyYiilkist/IdkzFEBHFGQZ2IqI4w8BORBRnGNiJiOIMAzsRUZxhYCciijMM7EREcYaBnYhi1uq9J7Hr+Dm7mxF1bJmgRERkhtunrgIA5E4ebXNLogvv2ImI4gwDOxFRnGFgJyKKMwzsREQBLM/Jx7Qfcu1uhmHsPCUiCuCej9YAAO7un25vQwziHTsRkY+zBUV48ettKCwusbspIWFgJyLy8daiHPzfd7mYue6Q3U0JCQM7EZGP4hIFAFBK2dyS0DCwExHFGQZ2IqI4w8BORBRnGNiJiOIMAzsRkR8x2nfKwE5EVFrqHcFFbGqISRjYiahCyz52Fq0mZmHJjuN2N8U0YQd2EUkRkTUisklEtonIi2Y0jIjIChsPnAYALNoeP4HdjFoxhQCuUkqdF5FkACtFZJ5SapUJ+yYioiCFHdiVY2rWeefTZOe/GO1yICLyduxMgd1NCJopOXYRSRSRHwHkAViklFptxn6JiKziOQLG8/Hh05esb0yYTAnsSqkSpVR3AE0B9BGRLr7biMgDIrJORNbl5+ebcVgiorDF+AAYTaaOilFKnQawDMBIjfemKqUylFIZaWlpZh6WiExw5lIRPl8fm9UMzVbhhzuKSJqI1HY+rgJgOIDscPdLRNZ68rNNeOKzTdhx9KzdTYkysddlaMaomEYA/iUiiXBcKGYqpeaYsF8istDxc4UAgIKi2FxcIlwqBgO4P2aMitkMoIcJbSEislysp120cOYpEVEE5Bw/h/QJc21JbTGwExFBu+BXOCsozd96DACQteVoyPsIFQM7EVVoojHg0fMVMyo8njhfiHxnH4YVzOg8JSIiHRmvLAYA5E4ebcnxeMdORA4m3Jq+u2QX/vC/zSY0JnoY6VzNPnYW6RPmIvtYdAwVZWAnqmCKSkqxau9Jv+9LGMNEpizKwX/XHQz58y7FJaUoKinVfG9t7imkT5iLXcfPhX0cT+Fc1rK2OPLprry63RjYiSqYNxfl4I6pq7B+/092N8Wv4W9+i7bPzNN8b86mIwCA73afCGnfR89c8h6rr3Ed8wzyhr7IRNlSSwzsRBXMruOOYqwnz2t35oUzEiSQ0lKFd5bswpmLRbrb5Z68GLE29H/tGzwyY6OhbQN9ezlfWBxwH56nc+vhM4aOGy4GdqIo9cTMTXj2yy3WHdCCmTrfZOfhzUU5ePHrbe7XLhQWY7HFi1wsNrhakt5FbuvhM+jywgLM2XzE8HGve3el4W3DwcBOFKU+33AI/1l1wO5mmMqVN794uSwVMnHWFvzq3+tMz5mHw8glbtsRx933IzM2oqC41Pk5jaGTNsxsZWAnqqCMJlwKikpQ7KcjM5AVu/LR+fn5OFfgP/XiSrsYSWsA1pfkMnK8A86fQavejB3pdwZ2IovknS3Aap3RKFYJdAf53je78dOFy+7nHZ6bj7Efr3U//2LDITz2qbEc9ZuLcnDhcglyjp/3et0zAG46eFp3HxcKizVTInr574XbjmFd7qmA7fvg2z14KojhmZ+u
OYD0CXNx6XJ0F0pjYCeyyJj3VuL2qVG8FLAzeC7JzsPEWd65/RW7ykagPD5zE7760Xhe2ZNvLD70k34n6cnzhej8wgK8983uoI7zwLT1uOWDHwJuN3leWYXxUgO31u8623HCT8ezFqZiiOLY8bPWTSl3KSlVKCkNPhdw0aI7Us+2lSqFrC1Hve7O850BdM5mc+qt6HWGfrHhsObrRuKyVgom1OGYZmBgJ4pjPV5aiL6vLvF6zZ4h194H1WrD35btwfjpG7Bgm/2TfFbtdaRxFLxbnuCMmEppD8n07Dxdve+Ue1urMbATxbGzBcVBpQ1czIpF+uW1vB05XQAA2J133u82kbZ0Zx7OXCrCTj8jdFyBu1Qp/G3ZHvfrruB9vrAIEz7fbLgjOFJYBIwoTt370RrN183M+Z6+eBlTFubg2es6onJSYrn3A10gtO5m/7wwx5zGhWDs/61FZttU9/NLRSV49JOyjmLXufP3c/1jxT4AQOPaVcp9xkq8YyeKU9/m5Ae1/aZDwc2KLC1VGPLGMkxbtR9fbdTvTF2x64TXpKRIZCdW7jphylj4fScuuB/P/vEIjp4pcD9PcEZp31y93s9jRyqGd+xEATz26UYcPVOAmb/ub3dTosqMNQdw5pJjfHqJT/T6fvcJ9GlZ1/38L4t3AQCa1qka1DG0gqJr4QrPImGvzduBv3+7F0BkS+O6br59+6N9a617tvuz9eEXRQsWAztRAKEO7YtlRurF5HkEM9/Nf/7P1Rg/tHW5jPrZS/4nKukd0TOdceL85XL7cgV1AEifMFdzH5eLS3HM4+47JO52eLf2R52x+HaMhmJgJzLJW4tycPxsASbf3NXuphhiZopAawz43vwL5V57e8ku049t1MRZW/C/9YcCbqeXEy9LxejvQ2v4o5WYYycyydtLduHTtdZ/7Q5WJPrytMKYXnDbeOAnv58z6p0gJy0F2+egxV8qJtowsBOZLFBJWrsZiUnBLrYRbKnfk86SBaHWoAlWUUmp4TVHD5665H7sSvu4lI2Kie7IzsBOcUkpFdG64nqWZFtbgjZUerHbde52HPVe6s1rgQqv7UNrw9VvLdd936z/wkmzt5V77YNv92hs6c13DoBW9UYtrs5iu4Qd2EWkmYgsFZEdIrJNRB4zo2FE4XhjwU60fDrL7/JqZCxofrqmrGxwQVEJOjw3X3M7rRx7sEHZ9yICAKPeWRHcTvzwrHXj4lknxijXxbA0yn+tzLhjLwbwhFKqI4B+AB4WkU4m7JcoZB9/nwsAKCy2/i/Q7C8KX/14GLvzzKtVvsjgohZvLsrBv37Y735+QWc2ZZStDBcxl2PkRiHswK6UOqqU2uB8fA7ADgBNwt0vkRlOX7wcUhEsM+w6fs69av1KjTtGox779EcMf1M/ZRHIuYIivLtkV1Dn4p0lxtMJ2p2n4S2M7Sn72DkopbzKCdvBNdLnzUU7bW1HIKbm2EUkHUAPAKs13ntARNaJyLr8/PB7p4n0uO4gB/1pKV6Zu92WNlz91nKM/MsKrM09hbs+LPcnoauw2Nzqiq9mZWPKohz8Z1XZHfjmQ/p10INhRX/G15uPYrtGusYOi3fk2d0EXaYFdhGpDuBzAL9VSpU7+0qpqUqpDKVURlpamlmHJQrINVPRKr4h7oTB0RiejIy3DoYrjTJ/a1nlxPc9ilhtCbKcgBFKmRvwj4cxuUgvjRSPTAnsIpIMR1CfrpT6wox9EsU7pRSW5+SjVCM9YiRl8s8Ve7HWwCpBnvxlRu6YGnhRCl/bj5RdDIwsUuHL3wxRPaFeJ07anMKxmhmjYgTAhwB2KKXeDL9JROEza5zxjNXBLyYdKKt8rqAIx84UYMG247jnozX46Lt9IbXtlbk7cKuBVYIAIPdk+VmgLgVFJbhgYGENV31xF890hFbAFTEvx07BMeOOfSCAuwFcJSI/Ov+NMmG/ROWcuViEni8vwvr9P+luF+jO7tiZAjwza0u5CTIXLxdjrsdqPZ5L
xJm12PK1b69Av9eW4NgZx0SYg6f0l4cLpLRU4ZlZW3CnzrJ7m52pFt84e+lyCQ4YPP5indE0vkXAAOOjb8h8ZoyKWamUEqVUV6VUd+e/LDMaR+Rr/YFTOHXhMt77xviIjeNnC8sFz6e/2Izpqw9ghc/yZS98tQ0Pz9igWdSpywsLDB1v9ib9omGHfrqk+74//vLErSZmYfrqA/jBuVD2vhMX/M7oPHXBe1Zsx+e1x6Vr+fJH7aXjAOD1+TvdlR4jKdhO6H8s3+t3UlU848xTqhCumrLM67lWClsphcOnHUH3fEHonW3LDdYkeX1BcEPm/vVDLi4FSJkcPHURV/55GR79dKPm+1qTgIwKlPYPZaWmYPwxa0dIn3lrsX0Ld9iFgZ1i1itztqOjn5mQvjGoqMRPVFJA9rGzSJ8wFy2fzsL3e04GPK5SCoNfX4rP1vkv+HXZwMSoYBeMnr7qADo+Px+HT1/Ckh3aaY7M15cCALK2hLduqFmDWQKlzKwQzkU6VjGwk23W7z+FAxoLAuvxHK73z5X7cEnja7ZSKmBg9cw1z1wb/NDCA6cu4sn/bfb7frirBWkFVte3if0nLuCb7OgbR/2P5XsDb0SWYD12ss3Nf3OM6AhmxZuZ6xxBeOlO/+mOZSaUZ/U3qua5L7eiVVq1gJ+fHsJomryzBejz6hI8cmUbpNWo7He7aC0ZG60li6P0dEUU79jJMsUlpThXEPkOtpPnjY9ZDnZY5LRV+/Hi18HNZP1W5yLkaZ7z28h7S/XrjCsod0cpkRYGdrLMIzM24opJC+1uhpdlO/M1g3v+uUK88NVWU47xX51cvKcXPErLfrfbf22Zuz9co7k6kZmivd446WNgr4BOni+0fAjYpoOnMX9beB164Vq99yQ2OFfucY3g+LdH9UJPE77Y4lXZ0GoLOQacwsAcewXU65XF6J1eB589OMCyY97w1+/cjyfN3qaZp56y0DH874kR7cM6lr/6JLc7J/DkTh6NYn+jZJyMjGqJZ0YXlIgFFaWksCcG9gpqba51w9B8A62rVrqvd51rWIYb2AOZvno/so+V1Tc/ejr04lIPT9+AwuIS/PPe3mY0zRafbyg/KoipmNjGVAyZbnfeeaRPmOsuUBVoJqbZAoWkV+d6T3QJJ0U0d8tRLN6RZ9syfGb4+7flhykGmggVSz5ZE/wIpVjHwE6mW7nLMQrk8Zk/AgDyzgaekRhO/fGc4+eQPmEuvt/j6HAMlGYxUvAqWPd8tMb0fdppwwHzarWTNysWfmFgp4g5eOoSthw6E/BrfWmpCuuXfZVz6N8852xLO9IIWmtqGlVUqsIuBEaxI5QSx8FiYKeglJSqoNIOZw2MW281MQtnL5k37TvWsiIzVh9wlwIgMgMDO/mVf64Qu46fw2frDrqDeeuJWXhi5ibdz+XknXc/NhpkPQtILc/JR2FxCc5cNDaZyfcYz35pzvhzoljFUTFxbu7mo9h25AweHdYWKcmJQX229x8Xux9nHzuH567rBAD4YuNhvHl7d/d75wuLsXDbMdzUsyk+WXPAa3GK/acu4NWs7KCOe89Ha9C5cU1sO3LWXW4gfcJc3Nmnueb2x86GPqqFKB4xsMew0xcvY+H242hRtyr6tqpX7v3iklI8PGMDAMdCCAdOXsTwjg0C7vfI6UtoVCvF67UPV+7Ds6M7up8vz8nHmn2n8NDQ1nh21hZ8+eMRtEqrjqe/2OL1Oc9FK4Kx7YijvOyrWTswcZTjuP5GN3ztHHUT7DJxRPGKgT2GHDx1Ec3qVnU/7/7SIvdjrUJanv2RZy4WYd7WY+56JP5sOPATbnr/e7x+c9dy73kusOwaBVJUWoqjzkWGL14OPU/ubwW1qcv3Ymh7/cXPXQtXZB87h6NnQlvEgiieMMceI5Zm5yHz9aXI2hLaHbCRpSffXrwLN73/PQBg3f7yd79HNVaJLywqRaFzlqbWyBYj9c0DeWSG9qIRWgLl/4kqAgb2KLQ0O6/cAgXb
nSvfbDl8RusjAZ0ysEr7B9/ucT82usyZUsq9jFw4E0H0FkMw0nYXMy4kRLGOgd1G87YcRfqEuUifMNdrHPPYj9fi5r99r/vZYO/cF2wrX1Rq0fbjfpcz09peywmPErnBrgjk6XadhZiJ4okVVXgY2G00x6NjcUGQ09rHT99Q7rXzhcVo98w8fJPtCMp78s+X28bTuH+vQ8YrZSNfikv1C19pjT7xvDAsM1h3nKgis2KaBQO7DeZsPoJr315haIbk8px8XDFpgXuF+p3HzmH7Ee0FiffkncflklL8ZfEuAHCnSAJJnzAXUxbu9L8uqNMMjVWBzhdWvPUkiaIdR8XY4NFPNqJUAZcN1Ed5Y8FOnCsoxvvLHPnvb7LzDK93GcwMTFdlxWBt83ORISL7mBLYReQjANcByFNKdTFjn9Fg5tqDWJt7Cm/c2g3pE+ZiZOeG6NG8NkqUwvihbQzv51xBERJEUK2y43SLCKAU9nisguMvCCckGM/I+e6CpVeJKiazUjEfAxhp0r6ixlOfb8ZnHmO35287htfmZeP1+TsNpzkA4IpJC9HjZceY87yzBUEVvEo0GNe/3Hi43MiS0wan5BNRfDHljl0ptVxE0s3YV6z4Yc9JdG9W2/38QmExdh4/h3YNaqBqcmK5O+3LxaVQSuF3zlK2vg6cuoj0CXPx4vWdvV5PNHjH/tv/lu1386EzOH3xMt5YsNPoj0NEcYQ59hD5Tvh57NONWLzDkfu+b0A6Jl3fGaWlCpsOld3ZT1u1H9/t1h5nPW2VY31NzwWNC4pKQl7pyHNWKhFFDyuqj1o2KkZEHhCRdSKyLj/f2mFxH3+3D0t2mLs4sCuuHzx1ET1fXuQO6gDwxYZDOH62AH9Zsgs3vl82Hv35r7YhGB2em29GU4mogrHsjl0pNRXAVADIyMiwtFdv0tfbAWjXUwGALYfOoEuTmo5OTYNEgF9PW6c5kedsQTH6vroktMYSEYWpQo1jv+7dFV7PD/10Ec/M2oIx763EfzTGaLscOV2+sNT5whLDszOJiFwuWDD3w5TALiKfAPgBQHsROSQi95uxX3+mLt+DwQZWnJm+er/XzMith88i71wBzhUU4S+LczDoT0sx3RnQczxWrV+z75RXudmdHu+5vLNkVzg/AhFVUJ5VUiPFrFExd5qxH6NcCzc8MmMD7urXAv00apHvO3EBz8zaiikLc7xez/zTUvRpWbfcGpXTVu3HSzd0hojgtr//ELnGE1GFZsX8kpgeFTNn81HM2XxUM3de4qx74lsZsLC41O+U/NfmZQe9yhARUTCsGBUT04Fdy5788ygoKkHlJP8B2t95nbp8r+brYz9ea0LLiIisEXOB/dsc7aGS6RPmej2fNX6AFc0hIgoKqzv6KCopxb3OJdkC8Rw/7iuYhRuIiGJNTAV234WSXXzv1omIKrKYCuxzNh+xuwlERGGJq5ICZigo0l/hh4iIYiywExHFOivGsTOwExHFGQZ2IiILMcdORERBY2AnIoozDOxERHGGgZ2IKM4wsBMRWUhZ0HsaU4F9w3NX290EIqKwHNZYkc1sMRXY61arZHcTiIjCYsUM+pgK7J46Narpfrz9pWtsbAkRUXSJycBeu2oynh3d0f28aqUk9Emva2OLiIiiR8wF9i2TRmDV08PKvd6jee1yr+3+47VWNImIKKrEXGCvkZKsuS7pk9e0x/ihrb1eS0os+/F2McgTUQURc4HdV40Ux+p+SYkJeGpkh3LvzxjXF3+/uxeSExOw/aVr8MyojuW2ISKyihXDHWNuzVMX16np0riW7nYDWqe6H1etlIRxg1shOVEw6evtEWwdEZF9YvaO3XXRE9F+/9+/7OP/sx6Pd74y0uu9j8f2dj/+863dQm0eEZGmmFnMWkRGishOEdktIhPM2GcgrmL1voG9UpLjRxrcLs3vZ5vUruJ+XDkpEd2aOu76v3x4IIa2r+9+72fdG6N65SRUrZSIFU9daVbTiYgiKuxUjIgk
AvgrgKsBHAKwVkRmK6UimutoUDMFANCtqfdomMW/G4Kc4+d0Pzuic0N8Mq4feqfXAVD+CppavTJOnC9EUmICtr7of4x8p0Y10bFRTXy+4VBQbe/YqCbm/GYQth85izHvrQzqs0QU26pVjnwG3Iw79j4Adiul9iqlLgP4FMANJuxXV7sGNZD1aCYev7qd1+vN61XF8E4NAn6+f+t67lEzvxzYEgCQXq8qAGDuo4MwY1xf3c+Py2yJrMcy0V1jmKU/lZ3fJuY9lonEBMEVTfX7B4go/qQklR/VZzYzAnsTAAc9nh9yvuZFRB4QkXUisi4/P9+EwwKdGtf0GtIYqp/1aILcyaNRu6qjZEGDmilena56fHu4p92vndsf2bkh1j47HGufGR5SG8cOTA/pc0QUXWJlzVOt7styLVdKTVVKZSilMtLS/Oe/Y4U4k/tXdajv9XpmW/8/W82UZKTVqBzS8VzfKogotsXK0niHADTzeN4UwBET9hvVXJ22TetURe7k0WHv74UxnQAAy34/FE+NbI+Xf9bF6/2qlUL/+jamW+Ow2kZEscWMwL4WQFsRaSkilQDcAWC2CfuNOu0aVHc/Fo0vKjf1cGSgvn1yKFqlVgtq32MHtkTu5NFIT62G8UPboLS07LKenCioVz20O30ASK3OqphE0SIm6rErpYoBPAJgAYAdAGYqpbaFu99otPB3Q/DOnT0AAC2cHa0uuZNH483buzvfq4asxzIBOIZW9mpRB7+/pr3f/VbTuBsvKikr7fnFQwM1Pzf9V33R0aPKpZa5jw4K6qvf337R02ssPxHFHlPG3SilsgBkmbGvaDemayOkVq+E/q3q6W6XkpyID+7qiZ7N66C+c2imlsk3XYE+LctXprylV1O8MncHAKCtxzcFTwPbpKJvy7rYcfSs3/13blwLn8H4cMx2DWugdZr28YgofDEzQakiEREMaJ3q7jzVM7JLI92gDgB39GmOVhqBtHbVSsidPBq5k0e7i559+fBAvH5zV6/tbs1oGrAdnl/9mtct+6bRrkF1zPx1/4CfJyLzxErnKVmke7PauK13M/zn/r74ZFw/AI47cn+dt0PbO0bojPUZUeOqglm/RgoSAl+fDBnYpuwbzB29m+lsSVSxxcpwR7LYoLap6N/afypo/bPDsfuP1+Kjex258vTUalj/rGP8fO/0uu7PGvkFc02qCiQlKRFPjXT0I2hV2fSHI3aoouEdO4WkXvXKSEpMQILH7Xi96pWx6HeD8epNXTRH9PjzvwcHaL5+Yw/vOWi/ymyFh4a0RvbLI1G3WiU0qqWfgnJJNOkbA1GsYI6dDGtTP3CHZ9sGNVDZYzqzUo4SDMH6YvwAvOUcAQQ4RgT1b10PIuLuD5j/28GG92f0IkAUD3jHTobN+c0gw9t69vvWr5GCazr7r61Tp1oyAO2USdajmVj6+6Gan6tVJdlwezLbGivfQETGMLDHCa3lAgMxcufQtE5VzP9tJp7zWDzcdV3o1LgmWgY5EUtLZ+diKeMyWTaB4l+9apGfMMjAXgH5prUD5dw7NKyJZI9ia4EmRQXrnv4tMP+3mfhDEJ2unob51OsJV7dmxit2VhQ1LCg1W1F0bmzu348WBvY48vYd3fGKT40ZPcEMu3KtLdu+QQ3D3w6+engg6hq4OxERdGio/8u+4qkr0aFhDc33/nFPhqH2GMUO3fJqBpFaI33sPKWg3NC9Ce7q1yLgdt2a1Uab+tXdd8i39Ao8ySkpMQG5k0djwe+Md4p2a1YbG567GjN+pV/b3ohmdavin/dqB/AEP4PxQy11LCJeE7mIYg0DewVUrXISFj8+BD2aO1aQGt6pQbmiZZltU3GrgYBvxIA2qVj+5JW4okkt9761eN7JGJjYG1D7Bo47/NsMzM71de0VDTVf79Ik8l+jKb5xVAzZZtr9ffGGiYt5ew6r/P2I9njrdv1973ut/GzaZJ1FVVzF2VyMpIC0NKqVgievae/3+/KEkR213yCK
IgzsZJlAd+GJzg0ecpY88NWgZgreuKUrptzaDaN87qiv9xmO+eV47YqYnva8Ogr39m/hVVbhh6eHoZ9OgbfkRMGWSSPcz8Opkx9LglkCkuzHwE6WCfQVNCFBkDt5tO7omFszmuHmXk3x/i964dnRHf3emTerWyVgexITBC/e0AV3a/RL6DW1RkpZR+L2l0YGPE48eP66TnY3IW6wVgxZpqtzYe3qFg1rMyPP+KvMVtjw3NWGtm1bvzp+NUh7nPzTozrirz/viX2vjXK/Fsq8AH8+9NPp62vWeO3yDXa7s09zNAhQpZSiCwM7AQAm39wVXz08MKJ/wFqpGCPlj0M7lriPpxSw6PEheHxEO81tU5ITMbprI6+2jB/aGo1qpaB3ep2A7c16NFO3LcM6+p/Z68nzm4CW8X5SVBRb2HlKlklJToz7iTlVKyUhd/Jo/Hpwq4DbpiQn4oenh+GzBwfg5p5lo2q0rkOdGtfEzlfMSMn4/4sfOzAd469sY8IxIqtXizqBN6KIY2AnW+mt/+gaHmm2p0d1DGoB8im3dQsYsDyLqwEoN3x0/m/17+oB/Tu5F8Z0tixNFo5OJs9KjkecoERxxTNwGcnAfPpAv7COF0x5YuP7dPi/sb0x9e5efrf7xqc4WqCZtVaJdMG1mlWi/+JTETCwky2M5BmrRdEd6ss3dMGA1vXQxfkt4sr29TGis/YkplCZeSdXu2r5fP2H92ZgiolzE7T85qq2Ed1/NOrWNDLfLMPBwE6WCbXzNBrytp0a18SMcf0MjZb56L7Ao2Cm3NoNGT4/V1r1ygE/l2ygkM3bd3THj8+PKPf6sI4NvDpoe0RgbLqZo4lihZGSHJ700o9mYWCnqLb48SH41y/72N2MoBipM9NCY4GTOtUqBQzcr93UVff9QDzHUBtND/lbxKW/zkSuisSKnHmwGNgpqrWpXz3sTsNo/MPzx9UJa3Zp5HCMHZBe7rXsl0di2v3RccENpiO8oggrsIvIrSKyTURKRcTc2qlEIXL/oUd5+V29LNTMX4fXcWyMsUueVvXMlOREJOnU7qlIfM/O0PZpuss9xsI49q0AbgKw3IS2EJmqaR1HWQF/ddztJKL9B66Vf93hUbbASH5Wb5NEjyAdaqG0QOoNQS1ZAAANCElEQVR4dNxe17VRUMs22umO3s1M2c+4zFaWBG89YQV2pdQOpdROsxpDFYdrxIvW6A2zDGiditmPDMQvB8beknuencpV/BQaG921UdD7rZyUiDm/GYStL16DSomR6ejc+PwI3OdM3/RsXsc9kijapdUI3Hlthsa1A9cxChe/S5Fl6tdwfD2tnJyAEZ0a4JWfdQl5OTyjujat7XchDruFU03h2i4N8fbt3TXvhgPtt0uTWjEx2SlmaJxwvf+DPi3rRrAxDgEDu4gsFpGtGv9uCOZAIvKAiKwTkXX5+fmht5hi1pTbumHKrd3QoWFNiAju6teiQg6Pc2nnXAjk532b4/1f9DT0GdedfJVKjhy31t2w0TRAJKsMugJbaQg5ieyXR6KnDWWCzUqfiIn7ClXAy7ZSargZB1JKTQUwFQAyMjJiaaACmaRWlWTcbNKqTPHg+TGdMKZbY6/67xNHd8Qzs7YiJSn6vkyvmTgMfV5dYmhbV239UAJcSnIimtWtig0HThv+jL8+CytoDfuMUG07w6Lvt4coRhmpAe+pclJiuUU9ftG3BXInj/Y74kSr89SzSFn1ykkY2MZRNuDLhwfizybONE01MIHKxdVJW1waWrRNDDJ9tvzJK0M6jietbzBG/k/9jfO3U7jDHW8UkUMA+gOYKyILzGkWUey5b4CjkzatRuRrl3vWwZlyW1nw3vriNe5OwO7NauvOihzcLs39eEQn79LCb9wS3kQoV79GKKkYAHh2dOCFPb5+ZJC7hn2zCC0+HlLzbfz24BJWD4pSahaAWSa1hSim3T+oJe73s5hHedZ9V/94bG/NQNOzuX6phsHt0rA8J7T+sCTXHXtJ
aBHOyFDMK0yu0aI9/DTEfdk8LY5d40Rx4K3buyElSbsjemj7+kHvTwH4xz29cK6guNx7RvLHN/Zogr8u3Y0x3YIfkumy7PdDMfTPy0L+vF1apVaPSGXRYDCwU9zo27JuxCbdRLsbe4TXKd2+YQ0s3H7c67XKSYmoXL3sYhHMPWirtOrY+1rwU/09//8a6szejASz7rEb1krhHTuRWf776/52NyHiIhUuHhvWFplt03Db338IuK3WvejWF6+B0f7OetUq4eSFywCAJrWr4PDpS+737JxyYGZe3O4cO0fFEFkkd/JodHcuP2j3cDhfSYkJ6NOyLlr6rPzkydXkIR6dri7VKyehaiVj94kTri2blOZ7Hqwu0ZycKLi6k/81aW/u2cTreaxM7GJgJ4ohkb4euOqlNK5VfphfQoJg+ZNX4m93+V85KlgJPpH9xh5lgVTr4ndnH2P1XNo3KKsPpFeQCyhbKEMrfXJjz6Z4bFjZ4iFGSjID9i8SExuXH6IYNPXuXtiVd97uZgRlXGYr9GpRBxnp2tPem2vUkQ+WZx2cYFMvnguF6Jl0fWf8/J+rAqZElPJoj4H0idFvWs3rVsW+Exe8XvvP/X2DnusQKt6xE0XIiM4N8fCVbexuRlASEsRvUDeL58pNeitoaY0sSTJ4JUgQuOsQVTY4i9fMtLjWvto2qI4W9fynuszEwE5koXCDR6zX4sidPBqt08pmamqUz/L72QcGtzJ8oRQRPDikNXInj9adxaoQ+C7c8327O0WNYmAnskG4uXKzOl8XPz7EPXszkjLbpmq/ofNz+P6ME0d1tCV3bSSYP3xl64DbWNlfzhw7kYXGdG2ETQdPo4kFNbmNsKrOycdj+2iWF4hUsDN64VNKuVM+4Swy/eQ1HfDXpXu89msnBnYiC90/qGWFLFecmCBI1AjjvqNiPJ+GE/SNly7Wvwgkini9P6B1PWw/ejaMllmDqRgiC4lIhQvqeswcz//Zg+FNUPO8GHz6QD9MGtOp3CggzzH4evymnizCwE4UQ67u2ACt0qrhwSGBc7pWqZES+hd/35Evns/0Rsxo6e0xmsd4KqZsqTrXGrkA0K9VPdznXFJxUJuyIG10Ae9xma3w3YSrjDUiApiKIYohdapVwjdPDLW7GW6bXhhheAiiFt8AHO66o71a1MH6/T8F9ZkxXRuhVpVkDGxdD5O+3l7ufb3hnx/4mawlImjg87NUD+MCGCzesRNRyGpVSQ5rpMrd/Vu4H//vwf7o4VFKOJwsjdHPPjS0NUQEQ9qllcv3B9KhYQ2M7NLQ0La7/3it4ZILZmBgJyLLue7yf9G3LLBHemKUS9ajmQAcE5civZi6i9EUjmnHs/RoRGRYZttUHP7pUuANY9CcRwdhabb+Ih6hdKy+dtMV+NO8bN1FOFqlOWZ/1qpirDxBLGJgJ4pS0+7va3cTIqZDw5ro0LCm6ftt16AGPryvt+42KcmJePlnXTDYZ+RKtFXcDAdTMUQUlUQEnz8UmRr7d/drEXbdlmguL8DATkRRq1eL8PPunsMm61SN3/SLJwZ2IqowVv7BvrHlVmJgJ6IKw+4FMKzCwE5EFGcY2ImIIkivHnykhBXYReQNEckWkc0iMktEagf+FBERRVK4d+yLAHRRSnUFkAPg6fCbRERkvWCLjv28b3ND29lRmz2sngSl1EKPp6sA3BJec4iIAqticenj2Y8MRE2fhbTvHZCu+5lgLxRmMrOL+JcA/mvi/oioAvjXL/tg5S798gKe1j87HMkGF6g2S9emsZVlDhjYRWQxAK0SZs8opb5ybvMMgGIA03X28wCABwCgeXNjX2GIKP4NaZeGIe3SDG9fr3p4pX0rgoCBXSk1XO99EbkXwHUAhimdZJJSaiqAqQCQkZERxZNxiYjCZ+e6p2GlYkRkJIA/ABiilLpoTpOIiOKHHbn2cHPs7wGoDGCRs/GrlFIPht0qIqIYk/VoJgqKS+xuBoDwR8W0MashRESRYNX9cqfG
5pchDhVnnhIReRjRqYHdTQhbxaiIQ0QV3lcPDwy4zeqJw1A7Dkr7MrATUYVQpVLgSU0NaqZY0JLIYyqGiChIrVIDr75k55hu3rETUVxrUDMFu/LOo1KiOfexsx8ZiGZ1qhre3o7CAgzsRBTVFj8+GEkJoQfld+/sgW+y85Bu4C7biFgoL8DATkRRrU39GmF9vk61Sri5V1OTWhMbmGMnIoozDOxERHGGgZ2IKIKsrh0PMMdORBQRyYkJmDiqA67qUN/yYzOwExFFyAODW9tyXKZiiIjiDAM7EVGcYWAnIoozDOxERHGGgZ2IKM4wsBMRxRkGdiKiOMPATkQUZ0Qp68vBi0g+gP2WH9hcqQBO2N2IKMLzUYbnwhvPh7dwzkcLpVRaoI1sCezxQETWKaUy7G5HtOD5KMNz4Y3nw5sV54OpGCKiOMPATkQUZxjYQzfV7gZEGZ6PMjwX3ng+vEX8fDDHTkQUZ3jHTkQUZxjYAxCRkSKyU0R2i8gEjfcfF5HtIrJZRJaISAs72mmFQOfCY7tbRESJSFyPhDByPkTkNufvxzYRmWF1G61k4G+luYgsFZGNzr+XUXa00woi8pGI5InIVj/vi4i84zxXm0Wkp6kNUErxn59/ABIB7AHQCkAlAJsAdPLZ5koAVZ2PHwLwX7vbbde5cG5XA8ByAKsAZNjdbpt/N9oC2AigjvN5fbvbbfP5mArgIefjTgBy7W53BM/HYAA9AWz18/4oAPMACIB+AFabeXzesevrA2C3UmqvUuoygE8B3OC5gVJqqVLqovPpKgBNLW6jVQKeC6eXAbwOoMDKxtnAyPkYB+CvSqmfAEAplWdxG61k5HwoADWdj2sBOGJh+yyllFoO4JTOJjcA+LdyWAWgtog0Muv4DOz6mgA46PH8kPM1f+6H4yocjwKeCxHpAaCZUmqOlQ2ziZHfjXYA2onIdyKySkRGWtY66xk5H5MA3CUihwBkAfiNNU2LSsHGlqBwzVN9ovGa5jAiEbkLQAaAIRFtkX10z4WIJAB4C8B9VjXIZkZ+N5LgSMcMheOb3AoR6aKUOh3httnByPm4E8DHSqkpItIfwDTn+SiNfPOijuHYEgreses7BKCZx/Om0Pj6KCLDATwD4HqlVKFFbbNaoHNRA0AXAMtEJBeOvOHsOO5ANfK7cQjAV0qpIqXUPgA74Qj08cjI+bgfwEwAUEr9ACAFjropFZGh2BIqBnZ9awG0FZGWIlIJwB0AZntu4Ew//B2OoB7POVTdc6GUOqOUSlVKpSul0uHob7heKbXOnuZGXMDfDQBfwtG5DhFJhSM1s9fSVlrHyPk4AGAYAIhIRzgCe76lrYweswHc4xwd0w/AGaXUUbN2zlSMDqVUsYg8AmABHL3+HymltonISwDWKaVmA3gDQHUAn4kIABxQSl1vW6MjxOC5qDAMno8FAEaIyHYAJQCeVEqdtK/VkWPwfDwB4B8i8js40g73KecQkXgjIp/AkYJLdfYpvAAgGQCUUh/A0ccwCsBuABcBjDX1+HF6XomIKiymYoiI4gwDOxFRnGFgJyKKMwzsRERxhoGdiCjOMLATEcUZBnYiojjDwE5EFGf+HwB96ndjyhotAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#从正态分布中抽取随机数\n",
    "#均值为0.5，标准差为torch.arange(0.1, 1, 0.0001)生成的序列（注意：normal的std参数是标准差而非方差）\n",
    "x = torch.normal(mean=0.5, std=torch.arange(0.1, 1, 0.0001))\n",
    "print(x)\n",
    "#均值为[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]，标准差为0.5\n",
    "y = torch.normal(mean=torch.arange(0.1,1,0.1),std=0.5)\n",
    "print(y)\n",
    "#均值为[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "#标准差为[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "z = torch.normal(mean=torch.arange(0.1,1,0.1),std=torch.arange(0.1,1,0.1))\n",
    "print(z)\n",
    "plt.plot(torch.arange(0.1, 1, 0.0001).data.numpy(),x.data.numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 序列化与反序列化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.6560, -1.4981,  0.7547],\n",
      "        [ 0.6677,  0.4550,  1.3934]])\n"
     ]
    }
   ],
   "source": [
    "#序列化模型\n",
    "x = torch.randn(2,3)\n",
    "#序列化torch.save方法\n",
    "torch.save(x,\"randn\")\n",
    "#反序列化torch.load方法\n",
    "x_load = torch.load(\"randn\")\n",
    "print(x_load)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 并行化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n",
      "4\n"
     ]
    }
   ],
   "source": [
    "#torch默认的线程数量等于计算机内核个数\n",
    "threads = torch.get_num_threads()\n",
    "print(threads)\n",
    "#可通过set_num_threads方法，设置并发数\n",
    "torch.set_num_threads(4)\n",
    "threads_1 = torch.get_num_threads()\n",
    "print(threads_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 元素级别的数学运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.0000, 0.5000, 0.1230, 0.4000, 0.5000, 0.9900])\n",
      "tensor([2.0000, 2.5000, 2.8770, 3.4000, 3.5000, 3.9900])\n",
      "tensor([0.5403, 0.8776, 0.9924, 0.9211, 0.8776, 0.5487])\n",
      "tensor([3.1416, 2.0944, 1.6941, 1.1593, 1.0472, 0.1415])\n"
     ]
    }
   ],
   "source": [
    "#求元素绝对值\n",
    "a = torch.Tensor([-1,-0.5,-0.123,0.4,0.5,0.99])\n",
    "print(torch.abs(a))\n",
    "#每个元素加“n”\n",
    "print(torch.add(a,3))\n",
    "#余弦\n",
    "print(torch.cos(a))\n",
    "#反余弦\n",
    "print(torch.acos(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.8000,  0.4000,  0.0770,  0.6000,  0.6000,  1.2900])\n",
      "tensor([-0.2000,  0.4000,  0.6770,  2.2000,  5.4000,  1.2900])\n"
     ]
    }
   ],
   "source": [
    "#tensor相除再相加:a+0.1*tensor_a/tensor_b,返回新的结果;需要注意的是a元素个数需要等于tensor_a/tensor_b元素个数\n",
    "x = torch.addcdiv(a,0.1,torch.Tensor([4,9,4,6,7,3]),torch.Tensor([2,1,2,3,7,1]))\n",
    "print(x)\n",
    "#tensor相乘再相加:a+0.1*tensor_a*tensor_b,返回新的tensor;需要注意的是a元素个数需要等于tensor_a*tensor_b元素个数\n",
    "y = torch.addcmul(a,0.1,torch.Tensor([4,9,4,6,7,3]),torch.Tensor([2,1,2,3,7,1]))\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-9., -1., -8., -0.,  2.],\n",
      "        [ 2.,  4.,  5.,  9., -7.]])\n",
      "tensor([[-10.,  -2.,  -8.,  -1.,   1.],\n",
      "        [  1.,   3.,   4.,   8.,  -8.]])\n",
      "tensor([[-5.5000, -1.9000, -5.5000, -0.3000,  1.5100],\n",
      "        [ 1.2000,  3.8000,  4.0100,  5.5000, -5.5000]])\n",
      "tensor([[-0.9120, -0.1900, -0.8000, -0.0300,  0.1510],\n",
      "        [ 0.1200,  0.3800,  0.4010,  0.8880, -0.7600]])\n",
      "tensor([[ 9.1200,  1.9000,  8.0000,  0.3000, -1.5100],\n",
      "        [-1.2000, -3.8000, -4.0100, -8.8800,  7.6000]])\n",
      "tensor([[-0.1096, -0.5263, -0.1250, -3.3333,  0.6623],\n",
      "        [ 0.8333,  0.2632,  0.2494,  0.1126, -0.1316]])\n",
      "tensor([0.8138, 0.9129, 0.5130, 0.4994, 0.3356])\n",
      "tensor([2., 4.])\n"
     ]
    }
   ],
   "source": [
    "#向上取整\n",
    "x = torch.Tensor([[-9.12,-1.9,-8,-0.3,1.51],[1.2,3.8,4.01,8.88,-7.6]])\n",
    "print(torch.ceil(x))\n",
    "#向下取整\n",
    "print(torch.floor(x))\n",
    "#夹逼函数，将每个元素限制在给定区间范围内,小于范围下限的被强制设置为下限值；大于上限的被强制设置为上限值\n",
    "print(torch.clamp(x,-5.5,5.5))\n",
    "#乘法\n",
    "print(torch.mul(x,0.1))\n",
    "#取相反数\n",
    "print(torch.neg(x))\n",
    "#取倒数\n",
    "print(torch.reciprocal(x))\n",
    "#取平方根倒数：每个元素的平方根倒数\n",
    "print(torch.rsqrt(x[x>0]))\n",
    "#求平方根\n",
    "print(torch.sqrt(torch.Tensor([4,16])))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.3333, 3.0000])\n",
      "tensor([3., 3.])\n",
      "tensor([0.1000, 0.3000, 1.0000, 0.0000, 1.0000])\n",
      "tensor([0.1000, 0.3000, 1.0000, 0.0000, 1.0000])\n",
      "tensor([0.1000, 0.3000, 0.0000, 0.0000, 0.0000])\n",
      "tensor([1., 3.])\n",
      "tensor([1., 2.])\n"
     ]
    }
   ],
   "source": [
    "#除法\n",
    "x =torch.Tensor([4,9])\n",
    "y = torch.div(x,3)\n",
    "print(y)\n",
    "z = torch.div(x,y)\n",
    "print(z)\n",
    "#计算除法余数：remainder相当于python中的%算子（结果符号同除数）；fmod类似C语言的fmod（结果符号同被除数），两者在操作数同号时结果相同\n",
    "q = torch.Tensor([2.1,2.3,5,6,7])\n",
    "print(torch.fmod(q,2))\n",
    "print(torch.remainder(q,2))\n",
    "#返回浮点数的小数部分\n",
    "print(torch.frac(q))\n",
    "#四舍五入\n",
    "print(torch.round(y))\n",
    "#指数运算\n",
    "a = torch.exp(torch.Tensor([0,math.log(2)]))\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2.])\n",
      "tensor([1.3133, 2.1269])\n",
      "tensor([1.4427, 2.8854])\n",
      "tensor([0.4343, 0.8686])\n",
      "tensor([2.7183, 7.3891]) tensor([ 7.3891, 54.5982])\n"
     ]
    }
   ],
   "source": [
    "#自然对数\n",
    "x = torch.Tensor([math.e,math.e**2])\n",
    "print(torch.log(x))\n",
    "#对输入x加1平滑处理后再求log\n",
    "print(torch.log1p(x))\n",
    "#2为底的对数\n",
    "print(torch.log2(x))\n",
    "#10为底的对数\n",
    "print(torch.log10(x))\n",
    "#幂运算\n",
    "print(torch.pow(x,1),torch.pow(x,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([10., 10., 10., 10., 10., 10., 10., 10., 10., 10.])\n",
      "tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])\n",
      "tensor([5.0000, 5.5000, 6.0000, 6.5000, 7.0000, 7.5000, 8.0000, 8.5000, 9.0000,\n",
      "        9.5000])\n"
     ]
    }
   ],
   "source": [
    "#线性插值：out_i = start_i + weight * (end_i - start_i)\n",
    "#带`_`线的方法为in-place类型算子，不会创建新的tensor而是改变原tensor的值\n",
    "x = torch.zeros(10).fill_(10)\n",
    "print(x)\n",
    "y = torch.arange(10).float()\n",
    "print(y)\n",
    "z = torch.lerp(x,y,0.5)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])\n",
      "tensor([0.0067, 0.0180, 0.0474, 0.1192, 0.2689, 0.5000, 0.7311, 0.8808, 0.9526,\n",
      "        0.9820])\n",
      "tensor([-1., -1., -1., -1., -1.,  0.,  1.,  1.,  1.,  1.])\n",
      "tensor([-0., -1., -1.,  0.,  2.,  2.])\n"
     ]
    }
   ],
   "source": [
    "#求每个元素的sigmoid值\n",
    "#sigmoid计算公式为1/(1+math.e^(-x))\n",
    "#sigmoid值位于(0,1),可视为概率值，在激活函数中应用较广\n",
    "x = torch.arange(-5,5,1).float()\n",
    "print(x)\n",
    "print(torch.sigmoid(x))\n",
    "#符号函数，根据元素的正负，返回+1和-1,元素0 返回0\n",
    "print(torch.sign(x))\n",
    "#截断值（标量x的截断值是最接近其的整数，其比x更接近零。简单理解就是截取小数点前面的数）\n",
    "print(torch.trunc(torch.Tensor([-0.9,-1.2,-1.9,0,2.1,2.7])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 规约计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 3., 4., 5., 6.],\n",
      "        [9., 8., 7., 6., 5.]])\n",
      "tensor([[2.0000e+00, 6.0000e+00, 2.4000e+01, 1.2000e+02, 7.2000e+02],\n",
      "        [9.0000e+00, 7.2000e+01, 5.0400e+02, 3.0240e+03, 1.5120e+04]])\n",
      "tensor([[ 2.,  3.,  4.,  5.,  6.],\n",
      "        [18., 24., 28., 30., 30.]])\n",
      "tensor([[ 2.,  5.,  9., 14., 20.],\n",
      "        [ 9., 17., 24., 30., 35.]])\n",
      "tensor([[ 2.,  3.,  4.,  5.,  6.],\n",
      "        [11., 11., 11., 11., 11.]])\n",
      "tensor([  720., 15120.])\n",
      "tensor([18., 24., 28., 30., 30.])\n",
      "tensor(55.)\n",
      "tensor([20., 35.])\n"
     ]
    }
   ],
   "source": [
    "#计算累积(Cumulative),可以通过dim指定沿着某一个维度计算累积\n",
    "x = torch.Tensor([[2,3,4,5,6],[9,8,7,6,5]])\n",
    "print(x)\n",
    "print(torch.cumprod(x,dim=1))\n",
    "print(torch.cumprod(x,dim=0))\n",
    "#计算累和\n",
    "print(torch.cumsum(x,dim=1))\n",
    "print(torch.cumsum(x,dim=0))\n",
    "#计算所有元素的乘积\n",
    "print(torch.prod(x,dim=1))\n",
    "print(torch.prod(x,dim=0))\n",
    "#计算所有元素的和\n",
    "print(torch.sum(x))\n",
    "print(torch.sum(x,dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(2.)\n",
      "tensor(8.)\n",
      "tensor(5.8310)\n",
      "tensor(5.)\n",
      "tensor([20., 35.])\n",
      "tensor([11., 11., 11., 11., 11.])\n",
      "tensor([ 9.4868, 15.9687])\n",
      "tensor([9.2195, 8.5440, 8.0623, 7.8102, 7.8102])\n"
     ]
    }
   ],
   "source": [
    "#距离公式，常用于模型损失值的计算。计算采用p-norm范数。p=1为曼哈顿距离；p=2为欧氏距离,默认为计算欧式距离\n",
    "x = torch.Tensor([[2,3,4,5,6],[9,8,7,6,5]])\n",
    "y = torch.Tensor([[2,3,4,5,6],[4,5,7,6,5]])\n",
    "print(torch.dist(x,y,p=0))\n",
    "print(torch.dist(x,y,p=1))\n",
    "print(torch.dist(x,y,p=2))\n",
    "print(torch.dist(x,y,np.inf))\n",
    "#norm范数\n",
    "print(torch.norm(x,p=1,dim=1))\n",
    "print(torch.norm(x,p=1,dim=0))\n",
    "print(torch.norm(x,p=2,dim=1))\n",
    "print(torch.norm(x,p=2,dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 196,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(5.5000)\n",
      "tensor([4., 7.])\n",
      "tensor([5.5000, 5.5000, 5.5000, 5.5000, 5.5000])\n",
      "tensor(5.)\n",
      "(tensor([4., 7.]), tensor([2, 2]))\n",
      "(tensor([2., 5.]), tensor([0, 4]))\n",
      "tensor(4.7222)\n",
      "tensor([24.5000, 12.5000,  4.5000,  0.5000,  0.5000])\n",
      "tensor(2.1731)\n",
      "tensor([4.9497, 3.5355, 2.1213, 0.7071, 0.7071])\n"
     ]
    }
   ],
   "source": [
    "#均值·中位数·众数·方差·标准差\n",
    "#均值\n",
    "x = torch.Tensor([[2,3,4,5,6],[9,8,7,6,5]])\n",
    "print(torch.mean(x))\n",
    "print(torch.mean(x,dim=1))\n",
    "print(torch.mean(x,dim=0))\n",
    "#中位数，指定dim将返回两个tensor，第一个是中位数，第二个tensor为index索引\n",
    "print(torch.median(x))\n",
    "print(torch.median(x,dim=1))\n",
    "#众数\n",
    "print(torch.mode(x))\n",
    "#方差\n",
    "print(torch.var(x))\n",
    "print(torch.var(x,dim=0))\n",
    "#标准差\n",
    "print(torch.std(x))\n",
    "print(torch.std(x,dim=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数值比较运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 0, 1],\n",
      "        [1, 0, 1]], dtype=torch.uint8)\n",
      "True\n",
      "False\n",
      "tensor([[1, 0, 1],\n",
      "        [1, 0, 1]], dtype=torch.uint8)\n",
      "tensor([[0, 0, 0],\n",
      "        [0, 0, 0]], dtype=torch.uint8)\n",
      "tensor([[1, 1, 1],\n",
      "        [1, 1, 1]], dtype=torch.uint8)\n",
      "tensor([[0, 1, 0],\n",
      "        [0, 1, 0]], dtype=torch.uint8)\n",
      "tensor([[0, 1, 0],\n",
      "        [0, 1, 0]], dtype=torch.uint8)\n"
     ]
    }
   ],
   "source": [
    "#元素相等比较，相等返回1，不相等返回0\n",
    "x = torch.Tensor([[2,3,5],[4,7,9]])\n",
    "y = torch.Tensor([[2,4,5],[4,8,9]])\n",
    "z = torch.Tensor([[2,3,5],[4,7,9]])\n",
    "print(torch.eq(x,y))\n",
    "#比较两个Tensor是否相等\n",
    "print(torch.equal(x,z))\n",
    "print(torch.equal(x,y))\n",
    "#逐一比较tensor1中元素是否大于等于tensor2中元素。\n",
    "print(torch.ge(x,y))\n",
    "#逐一比较tensor1中元素是否大于tensor2中元素。\n",
    "print(torch.gt(x,y))\n",
    "#逐一比较tensor1中元素是否小于等于tensor2中的元素\n",
    "print(torch.le(x,y))\n",
    "#逐一比较tensor1中元素是否小于tensor2中元素。\n",
    "print(torch.lt(x,y))\n",
    "#逐一比较两个tensor的对应元素是否不相等\n",
    "print(torch.ne(x,y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 212,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9.)\n",
      "(tensor([5., 9.]), tensor([2, 2]))\n",
      "tensor(2.)\n",
      "(tensor([2., 3., 5.]), tensor([0, 0, 0]))\n"
     ]
    }
   ],
   "source": [
    "#最大值·最小值\n",
    "x = torch.Tensor([[2,3,5],[4,7,9]])\n",
    "print(torch.max(x))\n",
    "#若指定了dim，返回两个tensor，第一个tensor为指定维度上的最大值；第二个tensor为指定维度上对应最大值所在的索引\n",
    "print(torch.max(x,dim=1))\n",
    "#最小值\n",
    "print(torch.min(x))\n",
    "print(torch.min(x,dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[ 3.,  5., 20.],\n",
      "        [ 4.,  9., 70.]]), tensor([[1, 2, 0],\n",
      "        [0, 2, 1]]))\n",
      "(tensor([[20., 70.,  9.],\n",
      "        [ 4.,  3.,  5.]]), tensor([[0, 1, 1],\n",
      "        [1, 0, 0]]))\n"
     ]
    }
   ],
   "source": [
    "#排序\n",
    "x = torch.Tensor([[20,3,5],[4,70,9]])\n",
    "#不指定dim，则默认按shape(-1)最后维度所在的方向进行升序排列；返回值第二个tensor为排序索引组成的tensor\n",
    "print(torch.sort(x))\n",
    "#指定descending关键字参数，设定升序还是降序\n",
    "print(torch.sort(x,dim=0,descending=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 224,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[20.,  5.],\n",
      "        [70.,  9.]]), tensor([[0, 2],\n",
      "        [1, 2]]))\n",
      "(tensor([[3., 5.],\n",
      "        [4., 9.]]), tensor([[1, 2],\n",
      "        [0, 2]]))\n",
      "(tensor([[20., 70.,  9.],\n",
      "        [ 4.,  3.,  5.]]), tensor([[0, 1, 1],\n",
      "        [1, 0, 0]]))\n",
      "(tensor([[20.,  5.],\n",
      "        [70.,  9.]]), tensor([[0, 2],\n",
      "        [1, 2]]))\n"
     ]
    }
   ],
   "source": [
    "#topK选择最大的或者最小的K个元素作为返回值\n",
    "x = torch.Tensor([[20,3,5],[4,70,9]])\n",
    "#k关键字参数指定返回最大或最小的几个元素,默认取最大的元素\n",
    "print(torch.topk(x,k=2))\n",
    "#largest设置为false表示取最小的topk值\n",
    "print(torch.topk(x,k=2,largest=False))\n",
    "#指定dim关键字参数则表示沿着dim维度所在的方向取topk\n",
    "print(torch.topk(x,k=2,dim=0))\n",
    "print(torch.topk(x,k=2,dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([19., 16., 21., 19., 19., 18., 15., 29., 23., 21.])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAENFJREFUeJzt3W+sZHV9x/H3p7srmmhB3WvdLLteU/GBGgW8pRhtS9E2/As0ERNMq2AxmxCpYGwtaIKRR6CNNhYjWQspKlEsUrvqEsU/VH0AetkuK7haNwbDFlpWUJCo2NVvH8zYDsPcO+feO3fn7o/3K5ns+fO7Zz7M7vns4ew5Z1JVSJLa8lvTDiBJmjzLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktSg9dN6440bN9bs7Oy03l6SDkt33HHHj6pqZty4qZX77Ows8/Pz03p7STosJflhl3GelpGkBlnuktQgy12SGmS5S1KDLHdJalDnck+yLsm/J/nciHVHJLkhyb4ktyeZnWRISdLSLOXI/SJg7wLrzgd+XFUvAD4AXLnSYJKk5etU7kmOBk4H/nGBIWcB1/WnbwRenSQrjydJWo6uR+5/D7wD+PUC6zcD9wJU1UHgYeDZK04nSVqWsXeoJjkDeKCq7khy0kLDRix7wjdvJ9kGbAPYunXrEmJq9pLPT+2977ni9Km9t6Tl6XLk/krgzCT3AJ8ETk7y8aEx+4EtAEnWA0cCDw1vqKq2V9VcVc3NzIx9NIIkaZnGlntVXVpVR1fVLHAO8JWq+ouhYTuAc/vTZ/fHPOHIXZJ0aCz7wWFJLgfmq2oHcA3wsST76B2xnzOhfJKkZVhSuVfVrcCt/enLBpb/AnjdJINJkpbPO1QlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQWPLPclTk3wzyZ1J7k7ynhFjzktyIMnu/uvNqxNXktRFl6/Zeww4uaoeTbIB+EaSm6vqtqFxN1TVhZOPKElaqrHlXlUFPNqf3dB/1WqGkiStTKdz7knWJdkNPADcUlW3jxj22iR7ktyYZMtEU0qSlqRTuVfVr6rqWOBo4IQkLxka8llgtqpeCnwJuG7UdpJsSzKfZP7AgQMryS1JWsSSrpapqp8AtwKnDC1/sKoe689+BHj5Aj+/varmqmpuZmZmGXElSV10uVpmJslR/emnAa8Bvjs0ZtPA7JnA3kmGlCQtTZerZTYB1yVZR+8vg09V1eeSXA7MV9UO4K1JzgQOAg8B561WYEnSeF2ultkDHDdi+WUD05cCl042miRpubxDVZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhrU5TtUn5rkm0nuTHJ3kveMGHNEkhuS7Etye5LZ1QgrSeqmy5H7Y8DJVfUy4FjglCQnDo05H/hxVb0A+ABw5WRjSpKWYmy5V8+j/dkN/VcNDTsLuK4/fSPw6iSZWEpJ0pJ0OueeZF2S3cADwC1VdfvQkM3AvQBVdRB4GHj2iO1sSzKfZP7AgQMrSy5JWlCncq+qX1XVscDRwAlJXjI0ZNRR+vDRPVW1varmqmpuZmZm6WklSZ0s6WqZqvoJcCtwytCq/cAWgCTrgSOBhyaQT5K0DF2ulplJclR/+mnAa4DvDg3bAZzbnz4b+EpVPeHIXZJ0aKzvMGYTcF2SdfT+MvhUVX0uyeXAfFXtAK4BPpZkH70j9nNWLbEkaayx5V5Ve4DjRiy/bGD6F8DrJhtNkrRc3qEqSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ
7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDeryHapbknw1yd4kdye5aMSYk5I8nGR3/3XZqG1Jkg6NLt+hehB4e1XtSvIM4I4kt1TVd4bGfb2qzph8REnSUo09cq+q+6tqV3/6p8BeYPNqB5MkLd+SzrknmaX3Zdm3j1j9iiR3Jrk5yYsX+PltSeaTzB84cGDJYSVJ3XQu9yRPBz4NXFxVjwyt3gU8r6peBvwD8JlR26iq7VU1V1VzMzMzy80sSRqjU7kn2UCv2K+vqpuG11fVI1X1aH96J7AhycaJJpUkddblapkA1wB7q+r9C4x5bn8cSU7ob/fBSQaVJHXX5WqZVwJvAL6dZHd/2TuBrQBVdTVwNnBBkoPAz4FzqqpWIa8kqYOx5V5V3wAyZsxVwFWTCiVJWhnvUJWkBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGdfkO1S1Jvppkb5K7k1w0YkySfDDJviR7khy/OnElSV10+Q7Vg8Dbq2pXkmcAdyS5paq+MzDmVOCY/uv3gQ/3f5UkTcHYI/equr+qdvWnfwrsBTYPDTsL+Gj13AYclWTTxNNKkjrpcuT+f5LMAscBtw+t2gzcOzC/v7/s/qGf3wZsA9i6devSkg6YveTzy/7ZlbrnitOn9t5PNv4+S8vX+R9Ukzwd+DRwcVU9Mrx6xI/UExZUba+quaqam5mZWVpSSVJnnco9yQZ6xX59Vd00Ysh+YMvA/NHAfSuPJ0laji5XywS4BthbVe9fYNgO4I39q2ZOBB6uqvsXGCtJWmVdzrm/EngD8O0ku/vL3glsBaiqq4GdwGnAPuBnwJsmH1WS1NXYcq+qbzD6nPrgmALeMqlQkqSV8Q5VSWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDunzN3rVJHkhy1wLrT0rycJLd/ddlk48pSVqKLl+z90/AVcBHFxnz9ao6YyKJJEkrNvbIvaq+Bjx0CLJIkiZkUufcX5HkziQ3J3nxhLYpSVqmLqdlxtkFPK+qHk1yGvAZ4JhRA5NsA7YBbN26dQJvLUkaZcVH7lX1SFU92p/eCWxIsnGBsduraq6q5mZmZlb61pKkBay43JM8N0n60yf0t/ngSrcrSVq+sadlknwCOAnYmGQ/8G5gA0BVXQ2cDVyQ5CDwc+CcqqpVSyxJGmtsuVfV68esv4repZKSpDXCO1QlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNchyl6QGWe6S1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQWPLPcm1SR5IctcC65Pkg0n2JdmT5PjJx5QkLUWXI/d/Ak5ZZP2pwDH91zbgwyuPJUlaibHlXlVfAx5aZMhZwEer5zbgqCSbJhVQkrR0kzjnvhm4d2B+f3+ZJGlK1k9gGxmxrEYOTLbRO3XD1q1bJ/DW0uqYveTzU3nfe644fSrv+2Q0rd9jODS/z5M4ct8PbBmYPxq4b9TAqtpeVXNVNTczMzOBt5YkjTKJct8BvLF/1cyJwMNVdf8EtitJWqaxp2WSfAI4CdiYZD/wbmADQFVdDewETgP2AT8D3rRaYSVJ3Ywt96p6/Zj1BbxlYokkSSvmHaqS1CDLXZIaZLlLUoMsd0lqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUoE7lnuSUJN9Lsi/JJSPWn5fkQJLd/debJx9VktRVl+9QXQd8CPgTYD/wrSQ7quo7Q0NvqKoLVyGjJGmJuhy5nwDsq6ofVNUvgU8CZ61uLEnSSnQp983AvQPz+/vLhr02yZ4kNybZMpF0kqRl6VLu
GbGshuY/C8xW1UuBLwHXjdxQsi3JfJL5AwcOLC2pJKmzLuW+Hxg8Ej8auG9wQFU9WFWP9Wc/Arx81IaqantVzVXV3MzMzHLySpI66FLu3wKOSfL8JE8BzgF2DA5Ismlg9kxg7+QiSpKWauzVMlV1MMmFwBeAdcC1VXV3ksuB+araAbw1yZnAQeAh4LxVzCxJGmNsuQNU1U5g59CyywamLwUunWw0SdJyeYeqJDXIcpekBlnuktQgy12SGmS5S1KDLHdJapDlLkkNstwlqUGWuyQ1yHKXpAZZ7pLUIMtdkhpkuUtSgyx3SWqQ5S5JDbLcJalBlrskNahTuSc5Jcn3kuxLcsmI9UckuaG//vYks5MOKknqbmy5J1kHfAg4FXgR8PokLxoadj7w46p6AfAB4MpJB5UkddflyP0EYF9V/aCqfgl8EjhraMxZwHX96RuBVyfJ5GJKkpaiS7lvBu4dmN/fXzZyTFUdBB4Gnj2JgJKkpVvfYcyoI/BaxhiSbAO29WcfTfK9Du9/KG0EfrTYgKytE05j807ChP6bD0nWCZpK3mV+1n62q2dVsq5wn3pel0Fdyn0/sGVg/mjgvgXG7E+yHjgSeGh4Q1W1HdjeJdg0JJmvqrlp5+jqcMp7OGWFwyvv4ZQVDq+8h1PWYV1Oy3wLOCbJ85M8BTgH2DE0Zgdwbn/6bOArVfWEI3dJ0qEx9si9qg4muRD4ArAOuLaq7k5yOTBfVTuAa4CPJdlH74j9nNUMLUlaXJfTMlTVTmDn0LLLBqZ/AbxustGmYs2eMlrA4ZT3cMoKh1fewykrHF55D6esjxPPnkhSe3z8gCQ16Elb7kmuTfJAkruGlv9V/1ELdyd577TyDRqVNcmxSW5LsjvJfJITpplxUJItSb6aZG//c7yov/xZSW5J8v3+r89cw1nfl+S7SfYk+ZckR007Kyycd2D9XyepJBunlXEgy4JZ19p+tsifgzW7n41VVU/KF/CHwPHAXQPL/hj4EnBEf/450865SNYvAqf2p08Dbp12zoFsm4Dj+9PPAP6D3qMr3gtc0l9+CXDlGs76p8D6/vIr10LWxfL257fQu/Dhh8DGtZp1Le5ni2Rds/vZuNeT9si9qr7GE6/FvwC4oqoe64954JAHG2GBrAX8dn/6SJ5478HUVNX9VbWrP/1TYC+9u5gHH1NxHfBn00n4/xbKWlVfrN7d1gC30bu/Y+oW+Wyh91yndzDiBsJpWCTrmtvPFsm6ZvezcZ605b6AFwJ/0H+y5b8l+b1pB1rExcD7ktwL/B1w6ZTzjNR/QuhxwO3A71TV/dDbmYDnTC/ZEw1lHfSXwM2HOs84g3mTnAn8Z1XdOdVQCxj6bNf0fjaU9bDYz0ax3B9vPfBM4ETgb4BPreEHoF0AvK2qtgBvo3evwZqS5OnAp4GLq+qRaedZzEJZk7wLOAhcP61sowzmpZfvXcBli/7QlIz4bNfsfjYi65rfzxZiuT/efuCm6vkm8Gt6z5ZYi84FbupP/zO9p3euGUk20NtJrq+q3+T87ySb+us3AVP/33FYMCtJzgXOAP68+idd14IReX8XeD5wZ5J76J1C2pXkudNL2bPAZ7sm97MFsq7p/WwxlvvjfQY4GSDJC4GnsHYfcHQf8Ef96ZOB708xy+P0j8KuAfZW1fsHVg0+puJc4F8PdbZhC2VNcgrwt8CZVfWzaeUbNipvVX27qp5TVbNVNUuvPI+vqv+aYtTF/hysuf1skaxrdj8ba9r/ojutF/AJ4H7gf+jtDOfT+0P2ceAuYBdw8rRzLpL1VcAdwJ30zg2+fNo5B/K+it4/RO0Bdvdfp9F7DPSX6e0gXwaetYaz7qP3GOvfLLt62lkXyzs05h7WxtUyC322a24/WyTrmt3Pxr28Q1WSGuRpGUlqkOUuSQ2y3CWpQZa7JDXIcpekBlnuktQgy12SGmS5S1KD/hcmyzvodj4TwAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#直方图操作\n",
    "#计算输入张量的直方图。以min和max为range边界，将其均分成bins个直条，然后将排序好的数据划分到各个直条(bins)中。\n",
    "#如果min和max都为0, 则利用数据中的最大最小值作为边界。\n",
    "x = torch.rand(200)\n",
    "y = torch.histc(x,min=0,max=1,bins=10)\n",
    "print(y)\n",
    "plt.hist(y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 矩阵运算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.5476, 0.7343])\n",
      "tensor([[0.5476, 0.0000],\n",
      "        [0.0000, 0.7343]])\n",
      "tensor([[0.0000, 0.0000, 0.5476, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.7343],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000]])\n",
      "tensor([[0.0000, 0.0000, 0.0000],\n",
      "        [0.5476, 0.0000, 0.0000],\n",
      "        [0.0000, 0.7343, 0.0000]])\n"
     ]
    }
   ],
   "source": [
    "#对角矩阵\n",
    "a = torch.rand(2)\n",
    "print(a)\n",
    "#diag设置对角矩阵，diagonal等于0，设置主对角线\n",
    "x = torch.diag(a,diagonal=0)\n",
    "print(x)\n",
    "#diagonal大于0，设置主对角线之上diagonal对应位置的值\n",
    "x = torch.diag(a,diagonal=2)\n",
    "print(x)\n",
    "#diagonal小于0，设置主对角线之下diagonal对应的值\n",
    "x = torch.diag(a,diagonal=-1)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.0000, 2.2200, 2.2220]])\n",
      "tensor([[0.6209, 0.6892, 0.6898]])\n",
      "tensor([[0.5371, 0.5962, 0.5967]])\n"
     ]
    }
   ],
   "source": [
    "#renorm：沿指定维度对张量的每个子张量重新规范化，使其p-norm范数不超过maxnorm\n",
    "x = torch.Tensor([[2,2.22,2.222]])\n",
    "print(x)\n",
    "#p指定p-norm范数，dim指定方向，maxnorm指定p-norm范数的最大值\n",
    "print(torch.renorm(x,p=1,dim=0,maxnorm=2))\n",
    "print(torch.renorm(x,p=2,dim=0,maxnorm=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1204, -0.6686,  0.0600, -1.4658],\n",
      "        [ 1.5003, -0.8799,  2.1509,  0.2911]])\n",
      "tensor(-1.0003)\n"
     ]
    }
   ],
   "source": [
    "#矩阵的迹\n",
    "#二维矩阵主对角线上元素之和\n",
    "x = torch.randn(2,4)\n",
    "print(x)\n",
    "print(torch.trace(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.7409, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.8660, 0.2698, 0.0000, 0.0000, 0.0000]])\n",
      "tensor([[ 0.7409,  0.4724, -0.5798,  0.0000,  0.0000],\n",
      "        [ 0.8660,  0.2698,  0.7263, -0.8031,  0.0000]])\n",
      "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.8660, 0.0000, 0.0000, 0.0000, 0.0000]])\n",
      "tensor([[ 0.7409,  0.4724, -0.5798, -0.2482,  0.9088],\n",
      "        [ 0.0000,  0.2698,  0.7263, -0.8031, -1.6912]])\n",
      "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.9088],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n",
      "tensor([[ 0.7409,  0.4724, -0.5798, -0.2482,  0.9088],\n",
      "        [ 0.8660,  0.2698,  0.7263, -0.8031, -1.6912]])\n"
     ]
    }
   ],
   "source": [
    "#下三角矩阵\n",
    "#参数diagonal控制对角线: diagonal = 0, 主对角线 ;diagonal > 0, 主对角线之上 ;diagonal < 0, 主对角线之下\n",
    "x = torch.randn(2,5)\n",
    "print(torch.tril(x))\n",
    "print(torch.tril(x,diagonal=2))\n",
    "print(torch.tril(x,diagonal=-1))\n",
    "#上三角矩阵\n",
    "#参数diagonal控制对角线: diagonal = 0, 主对角线 ;diagonal > 0, 主对角线之上 ;diagonal < 0, 主对角线之下\n",
    "print(torch.triu(x))\n",
    "print(torch.triu(x,diagonal=4))\n",
    "print(torch.triu(x,diagonal=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 375,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2, 3])\n",
      "torch.Size([1, 3, 1])\n",
      "tensor([[[ 46.],\n",
      "         [118.]]]) tensor([ 46., 118.])\n",
      "tensor([[ 46.],\n",
      "        [118.]])\n",
      "tensor(20.)\n"
     ]
    }
   ],
   "source": [
    "#bmm矩阵乘积\n",
    "#矩阵A的列数需等于矩阵B的行数\n",
    "#1*2*3\n",
    "x = torch.Tensor([[[1,2,3],[4,5,6]]])\n",
    "print(x.shape)\n",
    "#1*3*1\n",
    "y = torch.Tensor([[[9],[8],[7]]])\n",
    "print(y.shape)\n",
    "#res 1*2*1\n",
    "print(torch.bmm(x,y),torch.bmm(x,y).squeeze(0).squeeze(1))\n",
    "#注意和mm的区别：bmm为batch matrix multiplication(批矩阵乘法)，而mm为matrix multiplication(二维矩阵乘法)\n",
    "print(torch.mm(x.squeeze(0),y.squeeze(0)))\n",
    "#计算两个一维张量的点积torch.dot,两个向量对应位置相乘再相加\n",
    "x = torch.Tensor([1,2,3,4])\n",
    "y = torch.Tensor([4,3,2,1])\n",
    "print(torch.dot(x,y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2])\n",
      "torch.Size([1, 3])\n",
      "torch.Size([3, 2])\n",
      "tensor([[110.1000, 140.2000]])\n"
     ]
    }
   ],
   "source": [
    "#矩阵相乘再相加\n",
    "#addmm方法用于两个矩阵相乘结果再加到M矩阵，用beta调节M矩阵权重；用alpha调节矩阵乘积结果系数。\n",
    "#两个相乘的矩阵维度为2，分别表示[width，length]，第一个矩阵的length应该等于第二个矩阵的width满足矩阵相乘条件\n",
    "#out=(beta∗M)+(alpha∗mat1·mat2)\n",
    "x = torch.Tensor([[1,2]])\n",
    "print(x.shape)\n",
    "#batch1:1*1*3\n",
    "batch1 = torch.Tensor([[1,2,3]])\n",
    "print(batch1.shape)\n",
    "#batch2:1*3*2\n",
    "batch2 = torch.Tensor([[1,2],[3,4],[5,6]])\n",
    "print(batch2.shape)\n",
    "print(torch.addmm(x,batch1,batch2,beta=0.1,alpha=5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2])\n",
      "torch.Size([1, 1, 3])\n",
      "torch.Size([1, 3, 2])\n",
      "tensor([[220.1000, 280.2000]])\n"
     ]
    }
   ],
   "source": [
    "#批矩阵相乘再相加\n",
    "#addbmm方法用于批矩阵相乘结果再加到M矩阵，用beta调节M矩阵权重；用alpha调节矩阵乘积结果系数。\n",
    "#两个相乘的矩阵维度为3，分别表示[batchsize,width，length]，第一个矩阵的length应该等于第二个矩阵的width满足矩阵相乘条件\n",
    "#两个相乘矩阵batchsize应该相等。\n",
    "#res = beta*M + alpha*sum(batch1_i · batch2_i, i=0..b)\n",
    "#1*2\n",
    "x = torch.Tensor([[1,2]])\n",
    "print(x.shape)\n",
    "#batch1:1*1*3\n",
    "batch1 = torch.Tensor([[[1,2,3]]])\n",
    "print(batch1.shape)\n",
    "#batch2:1*3*2\n",
    "batch2 = torch.Tensor([[[1,2],[3,4],[5,6]]])\n",
    "print(batch2.shape)\n",
    "print(torch.addbmm(x,batch1,batch2,beta=0.1,alpha=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 352,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3])\n",
      "torch.Size([3, 1])\n",
      "torch.Size([1])\n",
      "tensor([130., 260., 390.])\n"
     ]
    }
   ],
   "source": [
    "#矩阵乘向量再相加\n",
    "#addmv方法用于矩阵和向量相乘结果再加到M矩阵，用beta调节M矩阵权重；用alpha调节矩阵同向量乘积结果。\n",
    "#矩阵的列应该等于向量长度以满足相乘条件\n",
    "# out=(beta∗tensor)+(alpha∗(mat·vec))\n",
    "#1*3\n",
    "x = torch.Tensor([1,2,3])\n",
    "print(x.shape)\n",
    "#batch1:1*3\n",
    "mat = torch.Tensor([[1],[2],[3]])\n",
    "print(mat.shape)\n",
    "#batch2:\n",
    "vec = torch.Tensor([3])\n",
    "print(vec.shape)\n",
    "print(torch.addmv(x,mat,vec,beta=100,alpha=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 373,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([[19.0221,  0.0000],\n",
      "        [ 6.1871,  0.0000],\n",
      "        [-2.2092,  0.0000]]), tensor([]))\n",
      "(tensor([[19.0221,  0.0000],\n",
      "        [ 6.1871,  0.0000],\n",
      "        [-2.2092,  0.0000]]), tensor([[ 0.3385,  0.7846, -0.0528],\n",
      "        [ 0.5373, -0.4213, -0.7286],\n",
      "        [ 0.7725, -0.4548,  0.6829]]))\n"
     ]
    }
   ],
   "source": [
    "#计算方阵的特征值和特征向量(eigenvector)\n",
    "#eigenvectors (bool) – 布尔值，如果True，则同时计算特征值和特征向量，否则只计算特征值。\n",
    "#注意：torch.eig在较新版本的PyTorch中已被弃用/移除，新代码应使用torch.linalg.eig\n",
    "x = torch.Tensor([[9,2,3],[4,5,8],[7,10,9]])\n",
    "print(torch.eig(x))\n",
    "print(torch.eig(x,eigenvectors=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
